#!/usr/bin/env python3

import argparse
import os
import sys
import torch

# Put the parent of the script directory on the import path (presumably where
# the `colorization` package lives -- confirm against the repository layout)
# so that the project imports below resolve when this file is run directly.
sys.path.insert(1, os.path.join(sys.path[0], ('..')))

import colorization.config as config
from colorization.util.argparse import nice_help_formatter


# Hand-written usage text handed to argparse through its `usage=` parameter.
# Every continuation line is indented by 23 columns so that, once argparse
# prepends its "usage: " prefix, the options line up under `[-h|--help]`.
# Required options are listed first; optional ones are bracketed.
USAGE = ("\n" + " " * 23).join((
    "run_training.py [-h|--help]",
    "--config CONFIG",
    "--data-dir DATA_DIR",
    "--checkpoint-dir CHECKPOINT_DIR",
    "--log-file LOG_FILE",
    "--iterations N",
    "--iterations-till-checkpoint N",
    "[--default-config CONFIG]",
    "[--init-model-checkpoint CHECKPOINT]",
    "[--continue-training]",
    "[--random-seed SEED]",
))


def _build_parser():
    """Create the argument parser for the training script.

    The usage text comes from the module-level ``USAGE`` string instead of
    being generated by argparse.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(formatter_class=nice_help_formatter(),
                                     usage=USAGE)

    parser.add_argument('--config',
                        required=True,
                        help="training configuration JSON file")

    parser.add_argument('--data-dir',
                        required=True,
                        help="training data directory")

    parser.add_argument('--checkpoint-dir',
                        required=True,
                        help="directory in which to save model checkpoints")

    parser.add_argument('--log-file',
                        required=True,
                        help="file to which to write log data")

    parser.add_argument('--iterations',
                        required=True,
                        metavar='N',
                        type=int,
                        help="number of iterations")

    parser.add_argument('--iterations-till-checkpoint',
                        required=True,
                        metavar='N',
                        type=int,
                        help="iterations between checkpoints")

    parser.add_argument('--default-config',
                        help="training fallback configuration JSON file")

    parser.add_argument('--init-model-checkpoint',
                        metavar='CHECKPOINT',
                        help="model checkpoint from which to initialize weights")

    parser.add_argument('--continue-training',
                        action='store_true',
                        help="pick up training from given checkpoint")

    parser.add_argument('--random-seed',
                        metavar='SEED',
                        type=int,
                        help="torch random seed")

    return parser


def main(argv=None):
    """Parse the command line, build model and data loader and run training.

    Args:
        argv: list of command line arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    An inconsistent command line (``--continue-training`` without
    ``--init-model-checkpoint``) is reported through ``parser.error``, which
    prints the usage message and exits with status 2.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Validate the flag combination before any other work is done, so a bad
    # invocation fails immediately instead of after the configuration has
    # been loaded and parsed.
    if args.continue_training and args.init_model_checkpoint is None:
        parser.error("no model checkpoint specified "
                     "(--continue-training requires --init-model-checkpoint)")

    # set random seed (before the configuration is parsed, as in the
    # original ordering)
    if args.random_seed is not None:
        torch.manual_seed(args.random_seed)

    # load configuration file(s)
    cfg = config.get_config(args.config, args.default_config)

    # command line values override the data root and the log file name
    config.modify_config(
        cfg, 'dataloader/params/dataset/params/root', args.data_dir)
    config.modify_config(
        cfg, 'model/params/log_config/handlers/file/filename', args.log_file)

    cfg = config.parse_config(cfg)

    dataloader = cfg['dataloader']
    model = cfg['model']

    # initialize model and run training
    train_kwargs = {}

    if args.continue_training:
        # hand the checkpoint to `train` itself so that training picks up
        # from it
        train_kwargs['checkpoint_init'] = args.init_model_checkpoint
    elif args.init_model_checkpoint is not None:
        # fresh run that only starts from the checkpoint: the optimizer
        # state is not loaded
        model.load_checkpoint(args.init_model_checkpoint,
                              load_optimizer=False)

    model.train(dataloader,
                iterations=args.iterations,
                iterations_till_checkpoint=args.iterations_till_checkpoint,
                checkpoint_dir=args.checkpoint_dir,
                **train_kwargs)


if __name__ == '__main__':
    main()
